CV Week: Итоговое задание¶
На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.
Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.
В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.
Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.
# # torch 2.4.1+cu124
!pip install diffusers==0.30.2 peft==0.8.2 huggingface_hub==0.23.4
Collecting diffusers==0.30.2 Downloading diffusers-0.30.2-py3-none-any.whl.metadata (18 kB) Collecting peft==0.8.2 Downloading peft-0.8.2-py3-none-any.whl.metadata (25 kB) Collecting huggingface_hub==0.23.4 Downloading huggingface_hub-0.23.4-py3-none-any.whl.metadata (12 kB) Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (7.0.0) Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (3.15.1) Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (1.26.4) Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (2024.5.15) Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (2.32.3) Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (0.4.5) Requirement already satisfied: Pillow in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (10.3.0) Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (21.3) Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (5.9.3) Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (6.0.2) Requirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (2.4.0) Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (4.46.3) Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (4.66.4) Requirement already satisfied: accelerate>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (1.1.1) Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub==0.23.4) (2024.6.0) Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub==0.23.4) (4.12.2) Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->peft==0.8.2) (3.1.2) Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (1.13.3) Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (3.3) Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (3.1.4) Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.10/site-packages (from importlib-metadata->diffusers==0.30.2) (3.19.2) Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (1.26.18) Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (2024.6.2) Requirement already satisfied: tokenizers<0.21,>=0.20 in /opt/conda/lib/python3.10/site-packages (from transformers->peft==0.8.2) (0.20.3) Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.13.0->peft==0.8.2) (2.1.5) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.13.0->peft==0.8.2) (1.3.0) Downloading diffusers-0.30.2-py3-none-any.whl (2.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.6/2.6 MB 29.4 MB/s eta 0:00:00a 0:00:01 Downloading peft-0.8.2-py3-none-any.whl (183 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 183.4/183.4 kB 14.9 MB/s eta 0:00:00 Downloading huggingface_hub-0.23.4-py3-none-any.whl (402 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 402.6/402.6 kB 29.2 MB/s eta 0:00:00 Installing collected packages: huggingface_hub, diffusers, peft Attempting uninstall: huggingface_hub Found existing installation: huggingface-hub 0.26.2 Uninstalling huggingface-hub-0.26.2: Successfully uninstalled huggingface-hub-0.26.2 Successfully installed diffusers-0.30.2 huggingface_hub-0.23.4 peft-0.8.2
from tqdm.auto import tqdm
import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
import functools
%matplotlib inline
import matplotlib.pyplot as plt
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
MODEL_DIR = '/kaggle/input/cv-week-2024/pytorch/default/1/cm_unet_3/'
#---------------------
# Visualization utils
#---------------------
def visualize_images(images):
assert len(images) == 4
plt.figure(figsize=(12, 3))
for i, image in enumerate(images):
plt.subplot(1, 4, i+1)
plt.imshow(image)
plt.axis('off')
plt.subplots_adjust(wspace=-0.01, hspace=-0.01)
#--------------
# Tensor utils
#--------------
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
#---------------
# Dataset utils
#---------------
class COCODataset(torch.utils.data.Dataset):
def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
"""
Arguments:
root_dir (string): Директория с картинками
transform (callable, optional): преобразования, применимые к картинкам
"""
self.root_dir = root_dir
self.transform = transform
self.extensions = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
sample_dir = os.path.join(root_dir, subset_name)
# Собираем пути до картинок
self.samples = sorted(
[
os.path.join(sample_dir, fname)
for fname in os.listdir(sample_dir)
if fname[-4:] in self.extensions
],
key=lambda x: x.split("/")[-1].split(".")[0],
)
self.samples = (
self.samples if max_cnt is None else self.samples[:max_cnt]
) #
# Собираем промпты
self.captions = {}
with open(
os.path.join(root_dir, f"{subset_name}.csv"), newline="\n"
) as csvfile:
spamreader = csv.reader(csvfile, delimiter=",")
for i, row in enumerate(spamreader):
if i == 0:
continue
self.captions[row[1]] = row[2]
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
sample_path = self.samples[idx]
sample = Image.open(sample_path).convert("RGB")
if self.transform:
sample = self.transform(sample)
return {
"image": sample,
"text": self.captions[os.path.basename(sample_path)],
"idxs": idx,
}
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type =='cuda'
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == 'cuda'
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == 'cuda'
# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()
# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]
Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.
Промпты¶
validation_prompts = [
"A sad puppy with large eyes",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
"A girl with pale blue hair and a cami tank top",
"A lighthouse in a giant wave, origami style",
"belle epoque, christmas, red house in the forest, photo realistic, 8k",
"A small cactus with a happy face in the Sahara desert",
"Green commercial building with refrigerator and refrigeration units outside",
]
SD 50 steps¶
generator = torch.Generator('cuda').manual_seed(1)
guidance_scale = 7.5
images = pipe(validation_prompts[0], num_images_per_prompt=4,
num_inference_steps=50,
guidance_scale=guidance_scale,
generator=generator).images
visualize_images(images)
0%| | 0/50 [00:00<?, ?it/s]
SD 4 steps¶
generator = torch.Generator('cuda').manual_seed(1)
guidance_scale = 7.5
images = pipe(validation_prompts[0], num_images_per_prompt=4,
num_inference_steps=4,
guidance_scale=guidance_scale,
generator=generator).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
Датасет¶
!wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
!tar -xzf train2014_5k.tar.gz
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
lambda x: 2 * x - 1,
])
dataset = COCODataset(".",
subset_name="train2014_5k",
transform=transform,
# max_cnt=2500
)
assert len(dataset) == 5000 # 2500
batch_size = 8 # Рекоммендуемы размер батча на Colab
train_dataloader = torch.utils.data.DataLoader(
dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
@torch.no_grad()
def prepare_batch(batch, pipe):
"""
Предобработка батча картинок и текстовых промптов.
Маппим картинки в латентное пространство VAE.
Извлекаем эмбеды промптов с помощью текстового энкодера.
Params:
Return:
latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
"""
# Токенизируем промпты
text_inputs = pipe.tokenizer(
batch['text'],
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
# Извлекаем эмбеды промптов с помощью текстового энкодера
prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]
# Переводим картинки в латентное пространство VAE
image = batch['image'].to("cuda", dtype=torch.float16)
latents = pipe.vae.encode(image).latent_dist.sample()
latents = latents * pipe.vae.config.scaling_factor
return latents, prompt_embeds
UNet¶
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.train()
unet.to('cuda')
assert unet.dtype == torch.float32
assert unet.training
LoRA¶
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
"to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out",
"ff.net.0.proj", "ff.net.2", "conv1", "conv2", "conv_shortcut",
"downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj"
]
lora_config = LoraConfig(
r=64, # задает ранг у матриц A и B в LoRA.
target_modules=lora_modules
)
CT UNet¶
# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")
# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()
# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss
def ddim_solver_step(model_output, x_t, t, s, scheduler):
"""
Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
params:
model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
t: torch.Tensor[B] - номер текущего шага
s: torch.Tensor[B] - номер следующего шага
scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
"""
alphas = torch.sqrt(scheduler.alphas_cumprod).to(x_t.device)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x_t.device)
sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
alphas_s = extract_into_tensor(alphas, s, x_t.shape)
sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
alphas_t = extract_into_tensor(alphas, t, x_t.shape)
# Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
alphas_s[s == 0] = 1.0
sigmas_s[s == 0] = 0.0
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
x_0 = (x_t - model_output * sigmas_t) / alphas_t # x0 оценка на шаге t
x_s = alphas_s * x_0 + sigmas_s * model_output # Переход на шаг s
return x_s
Реализация процесса зашумления (q sample)¶
def q_sample(x, t, scheduler, noise=None):
alphas = torch.sqrt(scheduler.alphas_cumprod).to(x.device)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x.device)
sigmas_t = extract_into_tensor(sigmas, t, x.shape)
alphas_t = extract_into_tensor(alphas, t, x.shape)
if noise is None:
noise = torch.randn_like(x)
x_t = alphas_t * x + sigmas_t * noise
return x_t
Consistency Training¶
Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.
Задание №3¶
Задание №3.1 (0.5 балла, сдается в контесте)¶
В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?
Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:
def get_xs_from_xt_naive(
x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
scheduler,
noise=None,
**kwargs
):
"""
Получение точки x_s в CT режиме, т.е., аналитически.
"""
if x_t is None:
return q_sample(x_0, s, scheduler, noise)
alphas = torch.sqrt(scheduler.alphas_cumprod).to(x_0.device)
sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x_0.device)
sigmas_s = extract_into_tensor(sigmas, s, x_0.shape)
alphas_s = extract_into_tensor(alphas, s, x_0.shape)
sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
alphas_t = extract_into_tensor(alphas, t, x_t.shape)
alphas_s[s == 0] = 1.0
sigmas_s[s == 0] = 0.0
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
pred_noise = (x_t - alphas_t * x_0) / sigmas_t
x_s = (
alphas_s / alphas_t * (x_t - sigmas_t * pred_noise) +
sigmas_s * pred_noise
)
return x_s
def cm_loss_template(
latents, prompt_embeds, # батч латентов и текстовых эмбедов
unet, scheduler,
# Функции, которые будем постепенно менять из задания к заданию
loss_fn: callable,
get_boundary_timesteps: callable,
get_xs_from_xt: callable,
num_timesteps=1000,
step_size=20, # Указываем с каким интервалом берем шаги s и t.
):
# Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
assert num_timesteps == 1000
num_intervals = num_timesteps // step_size
index = torch.randint(1, num_intervals, (len(latents),), device=latents.device).long() # [1, num_intervals]
t = step_size * index - 1
s = torch.clamp(t - step_size, min=0)
boundary_timesteps = get_boundary_timesteps(s, num_timesteps=num_timesteps)
# Сэмплируем x_t
noise = torch.randn_like(latents)
x_t = q_sample(latents, t, scheduler, noise)
noise_pred = unet(x_t.float(), t,
encoder_hidden_states=prompt_embeds.float(),
).sample
# Получаем оценку в граничной точке для x_t
boundary_pred = ddim_solver_step(noise_pred, x_t, t, boundary_timesteps, scheduler)
# Получаем сэмпл x_s из x_t
x_s = get_xs_from_xt(
latents, x_t, t, s,
scheduler,
prompt_embeds=prompt_embeds,
noise=noise,
)
# Предсказание "таргет моделью"
with torch.no_grad():
target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample
# Получаем оценку в граничной точке для x_s
boundary_target = ddim_solver_step(target_noise_pred, x_s, s, boundary_timesteps, scheduler)
loss = loss_fn(boundary_pred, boundary_target)
return loss
Boundary timestamps¶
def get_zero_boundary_timesteps(t, **kwargs):
"""
Определяем шаги где будут срабатывать граничные условия.
Для классических СM это t=0.
"""
return torch.zeros_like(t)
import functools
ct_loss = functools.partial(
cm_loss_template,
loss_fn=mse_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_naive
)
assert cm_unet.active_adapter == 'ct'
Обучающий цикл¶
Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.
Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение
Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.
Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01
def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):
torch.cuda.empty_cache()
scaler = torch.amp.GradScaler()
for i, batch in enumerate(tqdm(train_dataloader)):
latents, prompt_embeds = prepare_batch(batch, pipe)
with torch.amp.autocast('cuda', dtype=torch.float16):
loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler)
loss = loss / num_grad_accum
scaler.scale(loss).backward()
if (i + 1) % num_grad_accum == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
print(f"Loss: {loss.detach().item()}")
# # Обновляем параметры каждые 2 шага
# num_grad_accum = 2
# train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)
# cm_unet.save_pretrained('cm_unet')
Задание №5¶
Генерация с помощью обученной консистенси модели¶
Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:
@torch.no_grad()
def consistency_sampling(
pipe,
prompt,
num_inference_steps=4,
generator=None,
num_images_per_prompt=4,
guidance_scale=1
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
device = pipe._execution_device
# Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
do_classifier_free_guidance = guidance_scale > 0
prompt_embeds, null_prompt_embeds = pipe.encode_prompt(
prompt,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
device=device
)
# assert prompt_embeds.dtype == null_prompt_embeds.dtype == torch.float16
# Настраиваем параметры scheduler-a
assert pipe.scheduler.config['timestep_spacing'] == 'trailing'
pipe.scheduler.set_timesteps(num_inference_steps)
# Создаем батч латентов из N(0,I)
latents = pipe.prepare_latents(
batch_size * num_images_per_prompt,
pipe.vae.config['latent_channels'], # Number of channels in the latent space
pipe.vae.config['sample_size'], # Use sample_size for both height and width
pipe.vae.config['sample_size'], # Use sample_size for both height and width
dtype=torch.float16,
device=device,
generator=generator
)
for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
t = torch.tensor([t] * len(latents)).to(device)
zero_t = torch.tensor([0] * len(latents)).to(device)
cond_noise_pred = pipe.unet(latents, t, prompt_embeds).sample
if do_classifier_free_guidance:
uncond_noise_pred = pipe.unet(latents, t, null_prompt_embeds).sample
noise_pred = (
uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
)
else:
noise_pred = cond_noise_pred
# Получаем x_0 оценку из x_t
x_0 = ddim_solver_step(noise_pred, latents, t, zero_t, pipe.scheduler)
if i + 1 < num_inference_steps:
# Переход на следующий шаг
s = pipe.scheduler.timesteps[i+1]
s = torch.tensor([s] * len(latents)).to(device)
latents = q_sample(x_0, s, pipe.scheduler)
else:
latents = x_0
latents = latents.half()
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
do_denormalize = [True] * image.shape[0]
image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
return image
Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.
Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

cm_unet = PeftModel.from_pretrained(
model=unet,
model_id=f'{MODEL_DIR}/ct',
config=lora_config,
adapter_name="ct"
)
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'ct'
generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 3
# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
pipe,
prompt=validation_prompts[0],
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
Дистиляция¶
@torch.no_grad()
def get_xs_from_xt_with_teacher(
x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
scheduler,
prompt_embeds,
teacher_unet,
guidance_scale,
**kwargs
):
# Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
cond_noise_pred = teacher_unet(x_t, t, prompt_embeds)
# Для CFG нам нужно делать предсказания в unconditional случае.
# Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
# Извлечем эмбеды из пустого промпта и размножить их до размера батча
uncond_input_ids = pipe.tokenizer(
[""], return_tensors="pt", padding="max_length", max_length=77
).input_ids.to("cuda")
uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(
*prompt_embeds.shape
)
# Затем прогоняем модель для пустых промптов
uncond_noise_pred = teacher_unet(x_t, t, uncond_prompt_embeds)
# Применяем CFG формулу и получаем итоговый предикт учителя
noise_pred = (
uncond_noise_pred.sample + guidance_scale * (cond_noise_pred.sample - uncond_noise_pred.sample)
)
# Получаем x_s из x_t
x_s = ddim_solver_step(noise_pred, x_t, t, s, scheduler)
return x_s
# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
get_xs_from_xt_with_teacher,
teacher_unet=teacher_unet,
guidance_scale=7.5
)
Pseudo-Huber Loss¶
def pseudo_huber_loss(
x: torch.Tensor,
y: torch.Tensor,
c=0.001
):
diff = x - y
squared_norm = torch.norm(diff, p=2) ** 2
loss = torch.sqrt(squared_norm + c**2) - c
return loss
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")
# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher
)
assert cm_unet.active_adapter == 'cd'
# num_grad_accum = 2 # обновляем параметры каждые 2 шага
# train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)
# cm_unet.save_pretrained('cm_unet')
cm_unet = PeftModel.from_pretrained(
model=unet,
model_id=f'{MODEL_DIR}/cd',
config=lora_config,
adapter_name="cd"
)
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'cd'
generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 0
images = consistency_sampling(
pipe,
prompt=validation_prompts[0],
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
guidance_scale = 0
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = consistency_sampling(
pipe,
prompt=prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Задание №7 (0.25 балла, сдается в контесте)¶
Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.
def get_multi_boundary_timesteps(timesteps, num_boundaries=4, num_timesteps=1000):
"""
For a batch of timesteps, determine the corresponding boundary points.
params:
timesteps: torch.Tensor(batch_size, device='cuda')
returns:
boundary_timesteps: torch.Tensor(batch_size, device='cuda')
"""
boundary_timesteps = torch.zeros_like(timesteps)
if num_boundaries < 2:
return boundary_timesteps
boundary_points = torch.linspace(0, num_timesteps - 1, num_boundaries + 1).long().to(timesteps.device)
for i in range(len(boundary_points) - 1):
boundary_timesteps[
(boundary_points[i] <= timesteps) & (timesteps < boundary_points[i + 1])
] = boundary_points[i]
boundary_timesteps[timesteps >= boundary_points[-1]] = boundary_points[-2]
return boundary_timesteps
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")
# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
multi_cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_multi_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher
)
assert cm_unet.active_adapter == 'multi-cd'
# num_grad_accum = 2 # обновляем параметры каждые 2 шага
# train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)
# cm_unet.save_pretrained('cm_unet')
cm_unet = PeftModel.from_pretrained(
model=unet,
#model_id=f'{MODEL_DIR}/multi-cd',
model_id='/kaggle/input/notebooke7f6989e08/cm_unet/multi-cd',
config=lora_config,
adapter_name="multi-cd"
)
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'multi-cd'
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(2)
images = pipe(
prompt=prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
0%| | 0/4 [00:00<?, ?it/s]
cm_unet.push_to_hub(
"jmpleo/cv-week-2024",
token="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"
)
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors: 0%| | 0.00/269M [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/jmpleo/cv-week-2024/commit/d4a323d8efe76da7e1b727c4c1376b86ad443880', commit_message='Upload model', commit_description='', oid='d4a323d8efe76da7e1b727c4c1376b86ad443880', pr_url=None, pr_revision=None, pr_num=None)
loaded_cm_unet = PeftModel.from_pretrained(
unet,
"jmpleo/cv-week-2024",
token="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",
subfolder='multi-cd',
adapter_name="multi-cd",
)
multi-cd/adapter_config.json: 0%| | 0.00/895 [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == 'multi-cd'
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = pipe(
prompt=prompt,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale
).images
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
0%| | 0/4 [00:00<?, ?it/s]


